import jax.numpy as np 
import abc
import scalevi.distributions.distributions_branched_base as dist_branched_base


class Model(abc.ABC):
    """Abstract base class for models.

    Concrete subclasses must override both ``__init__`` and ``log_prob``.
    Because both are marked ``abc.abstractmethod``, a subclass that leaves
    either one abstract cannot be instantiated (``TypeError``).
    """

    @abc.abstractmethod
    def __init__(self):
        """Initialise the model; subclasses must override this."""

    @abc.abstractmethod
    def log_prob(self, z, params, chunk, **kwargs):
        """Evaluate the model's log-probability; subclasses must override this.

        NOTE(review): presumably returns the log density of ``z`` under
        ``params`` for the data ``chunk`` -- confirm against concrete
        subclasses.

        The base implementation does nothing and returns ``None``.
        """

class ModelBranched(dist_branched_base.BranchDist):
    """Branched model built on ``BranchDist``.

    The three sampling hooks defined here are placeholders: each always
    raises ``NotImplementedError``, so a subclass that needs sampling must
    override them.

    NOTE(review): the parent/child naming presumably reflects a hierarchical
    model with shared (parent) and per-chunk (child) variables -- confirm
    against ``BranchDist``.
    """

    def sample(self, rng_key, params, chunk):
        """Sampling hook; not implemented here (always raises)."""
        raise NotImplementedError()

    def sample_child(self, rng_key, params, θ, chunk):
        """Child-level sampling hook taking ``θ``; not implemented here (always raises)."""
        raise NotImplementedError()

    def sample_parent(self, rng_key, params):
        """Parent-level sampling hook; not implemented here (always raises)."""
        raise NotImplementedError()


